import argparse
import ast
from pathlib import Path

import kaldialign
from lhotse import CutSet

# Description shown by ``--help``; argparse re-wraps it, so line breaks are cosmetic.
ARGPARSE_DESCRIPTION = """
This helper code takes in a disfluent recogs file generated by icefall.utils.store_transcripts
(or a directory of such files), compares it against a fluent transcript, and saves the results
in a separate directory. This is useful to compare disfluent models with fluent models on the
same metric.
"""


def get_args(argv=None):
    """
    Parse the command-line arguments of this script.

    Args:
      argv: Optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
      An ``argparse.Namespace`` whose ``recogs``, ``cut`` and ``res_dir``
      attributes are ``pathlib.Path`` objects.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=ARGPARSE_DESCRIPTION,
    )
    parser.add_argument(
        "--recogs",
        type=Path,
        required=True,
        help="Path to a recogs-XXX file generated by icefall.utils.store_transcripts, "
        "or to a directory containing recogs-<part>-*.txt files for the parts "
        "eval1, eval2, eval3, excluded and valid.",
    )
    parser.add_argument(
        "--cut",
        type=Path,
        required=True,
        help="Path to the cut manifest to be compared to (its name must contain "
        "'csj_cuts'), or, when --recogs is a directory, the directory holding the "
        "csj_cuts_<part>.jsonl.gz manifests. Assumes that 'disfluent' and "
        "'disfluent_tag' exist in the supervision custom dict.",
    )
    parser.add_argument(
        "--res-dir", type=Path, required=True, help="Path to save results"
    )
    return parser.parse_args(argv)


def d2f(stats):
    """
    Score a disfluent model against a fluent reference (CER^d_f).

    Only errors on content (fluent) tokens are counted, so the result reflects
    how well a disfluent model does on the content words alone:

        CER^d_f = (sub_f + ins + del_f) / Nf

    Args:
      stats: Mapping whose ``"base"`` entry already holds the numerator
        (sub_f + ins + del_f) and whose ``"Nf"`` entry holds the number of
        fluent reference tokens. Any other keys are ignored.

    Returns:
      The ratio as a float.
    """
    numerator = stats["base"]
    denominator = stats["Nf"]
    return numerator / denominator


def calc_cer(refs, hyps):
    subs = {
        "F": 0,
        "D": 0,
    }
    ins = 0
    dels = {
        "F": 0,
        "D": 0,
    }
    cors = {
        "F": 0,
        "D": 0,
    }
    dis_ref_len = 0
    flu_ref_len = 0

    for ref, hyp in zip(refs, hyps):
        assert (
            ref[0] == hyp[0]
        ), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}."
        tag = ref[2].copy()
        ref = ref[1]
        dis_ref_len += len(ref)
        # Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively.
        flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)])
        hyp = hyp[1]
        ali = kaldialign.align(ref, hyp, "*")
        tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali]
        for tag, (ref_word, hyp_word) in zip(tags, ali):
            if "D" in tag or "F" in tag:
                tag = "D"
            else:
                tag = "F"

            if ref_word == "*":
                ins += 1
            elif hyp_word == "*":
                dels[tag] += 1
            elif ref_word != hyp_word:
                subs[tag] += 1
            else:
                cors[tag] += 1

    return {
        "subs": subs,
        "ins": ins,
        "dels": dels,
        "cors": cors,
        "dis_ref_len": dis_ref_len,
        "flu_ref_len": flu_ref_len,
    }


def _read_hyps(recogs_file: Path):
    """
    Read the hypotheses from a recogs file written by icefall.utils.store_transcripts.

    Hypothesis lines look like ``<cut_id>:<TAB>hyp=<python list literal>``.
    Reference lines (``<cut_id>:<TAB>ref=...``) and blank lines are skipped.

    Returns:
      A list of ``(cut_id, hyp_tokens)`` tuples sorted by cut id, so they line
      up with references sorted the same way.

    Raises:
      ValueError: on a line that is neither a hypothesis, a reference nor blank.
    """
    hyps = []
    with recogs_file.open(encoding="utf-8") as fin:
        for line in fin:
            cutid, sep, payload = line.partition(":\thyp=")
            if not sep:
                if line.strip() and ":\tref=" not in line:
                    raise ValueError(f"Unrecognised line in {recogs_file}: {line!r}")
                continue
            # literal_eval only accepts Python literals, unlike eval(), so a
            # malformed or malicious recogs file cannot execute code.
            hyps.append((cutid, ast.literal_eval(payload.strip())))
    hyps.sort(key=lambda x: x[0])
    return hyps


def _format_stat_table(stats) -> str:
    """Render the per-tag counts from ``calc_cer`` as a small CSV table."""
    rows = ["tag,yes,no"]
    for cer_type in ("subs", "dels", "cors"):
        counts = stats[cer_type]
        rows.append(f"{cer_type},{counts['D']},{counts['F']}")
    # Insertions have no reference token and therefore no tag. The total goes
    # in the first column and the second is left empty.
    rows.append(f"ins,{stats['ins']},")
    return "\n".join(rows)


def for_each_recogs(recogs_file: Path, refs, out_dir):
    """
    Score one recogs file against ``refs`` and write ``<stem>.dfcer`` to ``out_dir``.

    Args:
      recogs_file: recogs-XXX file written by icefall.utils.store_transcripts.
      refs: List of ``(cut_id, ref_tokens, ref_tags)`` tuples sorted by cut id.
      out_dir: Directory to write the result into. It must already exist.

    The output file contains:
      - the CER^d_f as a percentage;
      - the fluent reference length ``Nf``;
      - a blank line;
      - a CSV table of counts, where ``yes`` means the reference token has a
        CSJ D/F tag and ``no`` means it does not.

    Raises:
      ValueError: if the number of hypotheses differs from the number of references.
    """
    hyps = _read_hyps(recogs_file)
    if len(refs) != len(hyps):
        raise ValueError(
            f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."
        )

    stats = calc_cer(refs, hyps)
    stat_table = _format_stat_table(stats)

    cer_inputs = {
        "subd": stats["subs"]["D"],
        "deld": stats["dels"]["D"],
        "cord": stats["cors"]["D"],
        "Nf": stats["flu_ref_len"],
        # Errors on fluent tokens only. Insertions carry no tag and always count.
        "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
    }
    cer = d2f(cer_inputs)
    results = "\n".join([f"{cer:.2%}", f"Nf,{cer_inputs['Nf']}"])

    out_file = out_dir / (recogs_file.stem + ".dfcer")
    with out_file.open("w", encoding="utf-8") as fout:
        fout.write(results)
        fout.write("\n\n")
        fout.write(stat_table)


def _load_refs(cut_manifest: Path):
    """
    Load ``(cut_id, disfluent_chars, disfluent_tags)`` tuples from a cut manifest.

    Each cut's first supervision must carry ``disfluent`` (split into single
    characters) and ``disfluent_tag`` (a comma-separated string) in its custom
    dict. The result is sorted by cut id.
    """
    cuts = CutSet.from_file(cut_manifest)
    return sorted(
        (
            (
                cut.id,
                list(cut.supervisions[0].custom["disfluent"]),
                cut.supervisions[0].custom["disfluent_tag"].split(","),
            )
            for cut in cuts
        ),
        key=lambda x: x[0],
    )


def main():
    """
    Entry point: score one recogs file, or every recogs file in a directory.

    Raises:
      FileNotFoundError: if ``--recogs`` does not exist.
      ValueError: if a single recogs file is paired with a ``--cut`` whose name
        does not contain ``csj_cuts``.
      TypeError: if ``--recogs`` is a file not named ``recogs-*``.
    """
    args = get_args()
    recogs_path: Path = args.recogs
    if not (recogs_path.is_file() or recogs_path.is_dir()):
        raise FileNotFoundError(f"recogs_file cannot be found at {recogs_path}.")

    args.res_dir.mkdir(parents=True, exist_ok=True)

    if recogs_path.is_file() and recogs_path.stem.startswith("recogs-"):
        if "csj_cuts" not in args.cut.name:
            raise ValueError(f"Expected {args.cut} to be a cuts manifest.")
        for_each_recogs(recogs_path, _load_refs(args.cut), args.res_dir)

    elif recogs_path.is_dir():
        for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
            recogs_files = sorted(recogs_path.glob(f"recogs-{partname}-*.txt"))
            if not recogs_files:
                # Skip loading a manifest no recogs file will be scored against.
                # This also avoids failing on a partition whose manifest is absent.
                continue
            refs = _load_refs(args.cut / f"csj_cuts_{partname}.jsonl.gz")
            for recogs_file in recogs_files:
                for_each_recogs(recogs_file, refs, args.res_dir)

    else:
        raise TypeError(f"Unrecognised recogs file provided: {recogs_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
